
[Trainer] Fix distributed dataloader #8932

Merged: 4 commits merged into PaddlePaddle:develop on Aug 16, 2024

Conversation

@DesmonDay (Contributor) commented Aug 14, 2024

PR types

Bug fixes

PR changes

Others

Description

  1. Fix distributed dataloader.
  2. Fix rng state loading.
  3. Fix uc unittest.

Why the distributed dataloader hangs: the problem mainly affects warm-start (resume-from-checkpoint) runs with iterable datasets. In the original implementation, the data ranks receive the iterable dataset as input, so their sampler is of the Infinite type, while the non-data ranks receive None as input; when the input is None, Paddle's DataLoader automatically falls back to a batch sampler. After a warm start, the skip-data logic is normally executed, and it looks roughly as follows:
[Screenshot 2024-08-15 17:08:57: the skip-data logic]

As a result, the data ranks take the second branch while the non-data ranks take the first branch; the two groups follow inconsistent branch logic, so the job hangs. The stack traces captured at the time of the hang show the exact problem.
The hang on the data ranks is shown below:
[Screenshot 2024-08-15 17:06:42: data-rank stack trace]
The hang on the non-data ranks is shown below:
[Screenshot 2024-08-15 17:07:22: non-data-rank stack trace]
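To make the divergence concrete, here is a minimal sketch of the skip-data pattern described above. The function name, the flag argument, and the branch bodies are assumptions for illustration, not the Trainer's actual code.

```python
def skip_consumed_batches(dataloader, consumed_steps, uses_infinite_sampler):
    """Hypothetical sketch of the skip-data step described above (not the Trainer's code)."""
    if not uses_infinite_sampler:
        # Branch 1: non-data ranks. Their input was None, so Paddle's DataLoader
        # gave them a plain batch sampler; skipping only advances local indices
        # and performs no communication.
        for _ in range(consumed_steps):
            pass
    else:
        # Branch 2: data ranks, whose sampler is the Infinite type. Skipping here
        # consumes real batches, and with the distributed dataloader that involves
        # broadcasting data to the peer ranks, so it blocks forever while those
        # peers sit in branch 1.
        data_iter = iter(dataloader)
        for _ in range(consumed_steps):
            next(data_iter)
```

The fix synchronizes the iterable-dataset flag across ranks (see the `_is_iterable_dataset` all-reduce quoted near the end of this conversation), so every rank builds a loader with the same sampler type and takes the same branch.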

paddle-bot (bot) commented Aug 14, 2024

Thanks for your contribution!

codecov (bot) commented Aug 14, 2024

Codecov Report

Attention: Patch coverage is 32.69231% with 35 lines in your changes missing coverage. Please review.

Project coverage is 55.04%. Comparing base (75c7636) to head (2384c4d).
Report is 12 commits behind head on develop.

Current head 2384c4d differs from pull request most recent head fd9ffba

Please upload reports for the commit fd9ffba to get more accurate results.

Files Patch % Lines
paddlenlp/trainer/trainer.py 34.09% 29 Missing ⚠️
paddlenlp/data/dist_dataloader.py 25.00% 6 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8932      +/-   ##
===========================================
- Coverage    55.05%   55.04%   -0.02%     
===========================================
  Files          635      635              
  Lines        99412    99449      +37     
===========================================
+ Hits         54730    54739       +9     
- Misses       44682    44710      +28     

☔ View full report in Codecov by Sentry.

train_dataset,
batch_size=self.args.per_device_train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,

Collaborator commented: Would the same problem also be triggered with the plain DataLoader?

Contributor Author (@DesmonDay) replied: No; see the explanation of the hang in the PR description.

batch_size=self.args.per_device_train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
)

train_sampler = self._get_train_sampler()

Collaborator commented: The logic above is the is_iterable_dataset path, so the code below is the non-iterable-dataset logic?

@@ -1694,6 +1726,8 @@ def _load_rng_state(self, checkpoint):

if self.args.use_hybrid_parallel:
    if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:
        if self.args.tensor_parallel_degree <= 1:
            checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None)

Collaborator commented: Please explain clearly, in words, what causes the hang here.

Collaborator commented: +1

Contributor Author (@DesmonDay) replied: This line does not trigger the hang; it is just a bug fix. If tensor parallelism is not enabled but the checkpoint's rng_state still contains the tensor-parallel seed, loading it raises an error.

Contributor Author (@DesmonDay) replied: Explained in the PR description.

@@ -1398,12 +1398,15 @@ def get_train_dataloader(self):
    raise ValueError("We don't need train_dataset when should_load_dataset is False.")

train_dataset = self.train_dataset
if self.args.distributed_dataloader:
    is_iterable_dataset = self._is_iterable_dataset_dd(train_dataset)

Collaborator suggested a change:
- is_iterable_dataset = self._is_iterable_dataset_dd(train_dataset)
+ is_iterable_dataset = self._is_iterable_dataset_distributed(train_dataset)

batch_size=self.args.per_device_train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
is_iterable_dataset=True,

Collaborator commented: You could use an additional_args = {} dict and pass it via **additional_args, still keeping DistDataLoader and DataLoader merged into a single call site.

Contributor Author (@DesmonDay) replied: That does not really work, because Paddle's DataLoader does not accept a variable number of keyword arguments; it would require modifying Paddle itself.
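For context, the pattern being proposed would look roughly like the sketch below. The variable names are assumptions, and as the author's reply points out, it only helps if every loader tolerates the extra keyword; paddle.io.DataLoader has a fixed signature and raises a TypeError on unknown keyword arguments, short of changes in Paddle itself.

```python
# Hypothetical sketch of the reviewer's suggestion (not what the PR does):
# gather the DistDataLoader-only keyword arguments in a dict and splat them in,
# so DistDataLoader and DataLoader can share one construction site.
additional_args = {}
if self.args.distributed_dataloader:
    # Only DistDataLoader understands this flag; passing it to paddle.io.DataLoader
    # would raise a TypeError.
    additional_args["is_iterable_dataset"] = is_iterable_dataset

dataloader_cls = DistDataLoader if self.args.distributed_dataloader else DataLoader
train_dataloader = dataloader_cls(
    train_dataset,
    batch_size=self.args.per_device_train_batch_size,
    collate_fn=self.data_collator,
    num_workers=self.args.dataloader_num_workers,
    **additional_args,
)
```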

@@ -1694,6 +1726,8 @@ def _load_rng_state(self, checkpoint):

if self.args.use_hybrid_parallel:
    if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:
        if self.args.tensor_parallel_degree <= 1:
            checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None)

Collaborator commented: +1

@@ -1132,24 +1132,6 @@ def rerun(self, train_args):
np.testing.assert_allclose(res[0], res[-1], rtol=self.rtol)


@pytest.mark.skipif(True, reason="Skip for None CE")

Collaborator commented: Why was this removed?

Contributor Author (@DesmonDay) replied: If an ignore_merge_optimizer option is added in the future, it would conflict with skip_save_model_weight, so this was removed.

@@ -33,6 +33,11 @@ def __len__(self):
return 0


class IterableDummyDataset(paddle.io.IterableDataset):

Collaborator commented: I am wondering whether the fake dataset could instead be constructed on the dataset side itself.

Contributor Author (@DesmonDay) replied: I do not quite follow what you mean; the current implementation seems fine to me?
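The hunk above only shows the class declaration. Judging from the neighboring DummyDataset (whose __len__ returns 0), this looks like a placeholder for ranks that do not load real data; the sketch below shows what such a placeholder could look like, with the body being an assumption rather than the PR's actual code.

```python
import paddle


class IterableDummyDataset(paddle.io.IterableDataset):
    """Placeholder iterable dataset for ranks that never load real data (sketch)."""

    def __iter__(self):
        # Yield nothing: these ranks only need an object whose type makes the
        # DataLoader pick the same (iterable/Infinite) sampler as the data ranks,
        # so all ranks follow the same skip-data branch after a warm start.
        return iter([])
```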

@wawltor (Collaborator) previously approved these changes on Aug 15, 2024:

LGTM

@DesmonDay force-pushed the fix_dist_dataloader branch from 2384c4d to fd9ffba on August 15, 2024 at 12:28
@ZHUI merged commit e8708ed into PaddlePaddle:develop on Aug 16, 2024 (9 of 12 checks passed)
@SylarTiaNII (Contributor) left a comment: This needs a fix.

# For distributed dataloader.
is_iterable_dataset_tensor = paddle.to_tensor(self._is_iterable_dataset(dataset)).reshape([1])
if dist.get_world_size() > 1:
    dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX)

Contributor commented: NPU does not support bool-type collective communication, so this needs compatibility handling.
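One possible way to handle this, sketched below as an assumption rather than a committed fix: communicate the flag as an integer instead of a bool, so backends such as NPU that lack bool collectives can still participate. The helper name is hypothetical.

```python
import paddle
import paddle.distributed as dist


def sync_is_iterable_dataset(is_iterable: bool) -> bool:
    # Hypothetical workaround sketch: encode the bool as int32 before the
    # collective, since some backends cannot all_reduce bool tensors.
    flag = paddle.to_tensor([1 if is_iterable else 0], dtype="int32")
    if dist.get_world_size() > 1:
        # MAX makes the flag true on every rank if any rank saw an iterable dataset.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```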

Mangodadada pushed a commit to Mangodadada/PaddleNLP that referenced this pull request Sep 10, 2024
* fix ddloader, fix uc unittest

* update dataloader